import math
import numpy as np
import random
import matplotlib.pyplot as plt

def pull_arm(mu, i, a):
    """Sample one Bernoulli reward for agent ``i`` pulling arm ``a``.

    A uniform number in [0, 1) is drawn first and compared against
    ``mu[i][a]``; the result is 1 when the draw falls below that mean,
    otherwise 0.
    """
    sample = random.uniform(0, 1)
    return 1 if sample < mu[i][a] else 0
def Confidence(t, n_i_k_t, N_i, beta_i):
    """Return the UCB exploration bonus used by ``DUCB``.

    bonus = (1 + beta_i) * sqrt(3 * ln(t) / (N_i * n_i_k_t)) + 1 / (2 * t)

    Args:
        t: current round (ln(t) is taken, so t must be positive).
        n_i_k_t: number of pulls of the arm counted for this agent.
        N_i: number of entries in the agent's neighbour list.
        beta_i: agent-specific exploration coefficient.
    """
    scale = 1 + beta_i
    radius = math.sqrt(3 * math.log(t) / (N_i * n_i_k_t))
    return scale * radius + 1 / (2 * t)

def DUCB(N,M,T,mu,best_arm,W,beta,Neighbor):
    """Simulate one run of gossip-based decentralized UCB.

    Every agent ``i`` pulls one of the ``M`` arms per round and observes a
    Bernoulli reward with mean ``mu[i][k]`` (via ``pull_arm``). Each agent
    keeps:

    * ``hat_x[i, k]`` – running (reward sum, pull count) of its own pulls;
    * ``z[i, k]``     – a gossiped estimate: the ``W``-weighted average of its
      neighbours' ``z`` plus the change in its own local mean this round;
    * ``n[i, k]``     – its own pull count;
    * ``m[i, k]``     – the largest pull count it has heard of from neighbours.

    An agent whose own count lags the network (``n <= m - M``) pulls a random
    lagging arm; otherwise it pulls the arm maximising ``z + Confidence(...)``.

    Args:
        N: number of agents.
        M: number of arms.
        T: horizon; rounds ``t = 1 .. T-1`` are simulated after a warm start.
        mu: N x M nested sequence of per-agent Bernoulli means.
        best_arm: arm used as the regret benchmark.
        W: N x N mixing weights; only ``W[i][j]`` with ``j`` in
            ``Neighbor[i]`` is used.
        beta: per-agent exploration coefficients (length N).
        Neighbor: ``Neighbor[i]`` is an iterable of agent *indices* agent
            ``i`` gossips with (list ``i`` itself to include a self-loop).

    Returns:
        List of N lists. ``regret_list[i][t]`` is agent ``i``'s cumulative
        regret after ``t`` rounds against the network-average mean
        ``mean_j mu[j][best_arm]``; each list starts with 0 and has length T.
    """
    mu_star = np.mean(mu, axis=0)  # network-average mean of each arm
    regret_list = [[0] for _ in range(N)]

    # Warm start: every agent pulls every arm `retrain_time` times.
    retrain_time = 10
    hat_x = np.zeros((N, M, 2))  # [..., 0] = reward sum, [..., 1] = pull count
    for _ in range(retrain_time):
        for i in range(N):
            for k in range(M):
                hat_x[i, k, 0] += pull_arm(mu, i, k)
                hat_x[i, k, 1] += 1
    z = hat_x[:, :, 0] / hat_x[:, :, 1]

    # n and m MUST be distinct arrays: if they alias, `n <= m - M` is never
    # true (no catch-up exploration) and the gossip max overwrites n.
    n = np.full((N, M), float(retrain_time))
    m = n.copy()

    # W restricted to each agent's neighbour list; `+=` reproduces the
    # per-entry summation if a neighbour happens to be listed twice.
    W_eff = np.zeros((N, N))
    for i in range(N):
        for j in Neighbor[i]:
            W_eff[i, j] += W[i][j]
    neighbor_idx = [list(Neighbor[i]) for i in range(N)]

    for t in range(1, T):
        a = np.zeros(N, dtype=int)
        # A real copy: the gossip correction below is the difference between
        # the updated and the previous local means, which is 0 if aliased.
        new_hat_x = hat_x.copy()

        for i in range(N):
            lagging = [k for k in range(M) if n[i, k] <= m[i, k] - M]
            if lagging:
                a[i] = random.choice(lagging)
            else:
                Q = [z[i, k] + Confidence(t, n[i, k], len(Neighbor[i]), beta[i])
                     for k in range(M)]
                a[i] = int(np.argmax(Q))

            arm = int(a[i])
            reward = pull_arm(mu, i, arm)
            new_hat_x[i, arm, 0] += reward
            new_hat_x[i, arm, 1] += 1
            regret_list[i].append(regret_list[i][-1] + mu_star[best_arm] - mu_star[arm])

        # Synchronous gossip step: z(t+1) and m(t+1) depend only on round-t
        # values, independent of the order agents are visited in.
        n[np.arange(N), a] += 1
        local_delta = (new_hat_x[:, :, 0] / new_hat_x[:, :, 1]
                       - hat_x[:, :, 0] / hat_x[:, :, 1])
        new_z = W_eff @ z + local_delta

        new_m = np.maximum(n, m)
        for i in range(N):
            if neighbor_idx[i]:
                new_m[i] = np.maximum(new_m[i], m[neighbor_idx[i]].max(axis=0))

        hat_x, z, m = new_hat_x, new_z, new_m

    return regret_list

def Distri(agent_counts=(5, 8, 11, 14, 17, 20), M=20, T=10000, repeated_time=50):
    """Measure DUCB's final regret on complete graphs of varying size.

    For each agent count N, runs ``DUCB`` ``repeated_time`` times on a
    complete communication graph with uniform weights ``1/N``. It records agent
    0's cumulative regret at the end of each run.

    Arm means: agent ``n``'s arm ``m`` has mean ``0.95 - 0.03*m - 0.01*n``,
    so arm 0 is best for every agent.

    Args:
        agent_counts: network sizes to evaluate (default matches the original
            hard-coded 5, 8, 11, 14, 17, 20).
        M: number of arms.
        T: horizon passed to ``DUCB``.
        repeated_time: independent runs per network size.

    Returns:
        ``(means, lower, upper)``: three lists with one entry per agent count,
        holding the mean final regret, mean - std, and mean + std.
    """
    best_arm = 0
    means, lower, upper = [], [], []

    for N in agent_counts:
        mu = [[0.95 - 0.03 * arm - 0.01 * agent for arm in range(M)]
              for agent in range(N)]
        # DUCB iterates Neighbor[i] as agent *indices*. Rows of all ones would
        # make every "neighbour" agent 1; a complete graph is range(N).
        neighbors = [list(range(N)) for _ in range(N)]
        W = [[1 / N for _ in range(N)] for _ in range(N)]
        beta = [1] * N

        final_regret = np.array([
            DUCB(N, M, T, mu, best_arm, W, beta, neighbors)[0][-1]
            for _ in range(repeated_time)
        ])
        mean = np.mean(final_regret)
        std = np.std(final_regret)
        means.append(mean)
        lower.append(mean - std)
        upper.append(mean + std)

    return (means, lower, upper)